{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_cell_guid": "79c7e3d0-c299-4dcb-8224-4455121ee9b0",
    "_uuid": "d629ff2d2480ee46fbb7e2d37f6b5fab8052498a"
   },
   "outputs": [],
   "source": [
    "import numpy as np \n",
    "import pandas as pd \n",
    "import re\n",
    "\n",
    "import os\n",
    "import logging\n",
    "import gc\n",
    "from pathlib import Path\n",
    "import pickle\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.metrics import log_loss\n",
    "from sklearn.model_selection import StratifiedKFold\n",
    "\n",
    "from torch.optim.lr_scheduler import CosineAnnealingLR\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "from pytorch_pretrained_bert import BertTokenizer\n",
    "from pytorch_pretrained_bert.modeling import BertModel\n",
    "\n",
    "'''\n",
    "Forked and edited from:\n",
    "https://www.kaggle.com/ceshine/pytorch-bert-baseline-public-score-0-54\n",
    "\n",
    "We use this notebook to generate BERT embeddings for two mentions and the gender pronoun.\n",
    "We remove punctuation during data pre-processing at this time.\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The working data set is the GAP test split followed by the GAP validation split.\n",
    "# ignore_index=True gives the combined frame one fresh 0..n-1 index; without it the\n",
    "# row labels of the two files repeat, so a label-based lookup would hit two rows.\n",
    "GAP_TRAIN_FILES = (\"gap-test.tsv\", \"gap-validation.tsv\")\n",
    "train_df = pd.concat(\n",
    "    [pd.read_csv(path, sep=\"\\t\") for path in GAP_TRAIN_FILES],\n",
    "    axis=0,\n",
    "    ignore_index=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "BERT_MODEL = 'bert-large-uncased'\n",
    "\n",
    "# Tokens passed to the tokenizer as never_split: the standard special tokens plus\n",
    "# the three custom markers used to locate mention A, mention B and the pronoun.\n",
    "NEVER_SPLIT_TOKENS = (\n",
    "    \"[UNK]\", \"[SEP]\", \"[PAD]\", \"[CLS]\", \"[MASK]\",\n",
    "    \"[THISISA]\", \"[THISISB]\", \"[THISISP]\",\n",
    ")\n",
    "tokenizer = BertTokenizer.from_pretrained(BERT_MODEL, never_split=NEVER_SPLIT_TOKENS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def insert_tag(row):\n",
    "    \"\"\"\n",
    "    Insert custom tags to help us find the position of A, B, and the pronoun after tokenization.\n",
    "    \"\"\"\n",
    "    to_be_inserted = sorted([(row[\"A-offset\"], \" THISISA \"),(row[\"B-offset\"], \" THISISB \"),(row[\"Pronoun-offset\"], \" THISISP \")], key=lambda x: x[0], reverse=True)  # 从大往小插入这样才不会乱顺序    \n",
    "    text = row[\"Text\"]    \n",
    "    for offset, tag in to_be_inserted:\n",
    "        text = text[:offset] + tag + text[offset:]\n",
    "    return text\n",
    "\n",
    "def clean_and_replace_target_name(row):\n",
    "    \"\"\"\n",
    "    Reduce the tagged text to letters only and shorten both target names.\n",
    "\n",
    "    Steps, in order:\n",
    "      1. Every character of row['TextClean'], row['A'] and row['B'] that is not an\n",
    "         ASCII letter is replaced by a space.\n",
    "      2. Each cleaned name is used as a regex pattern; every match in the text is\n",
    "         replaced by the first token the module-level `tokenizer` produces for it.\n",
    "      3. The bare markers THISISA / THISISB / THISISP are wrapped in square brackets.\n",
    "      4. Runs of spaces are collapsed to a single space.\n",
    "\n",
    "    Returns the resulting string.\n",
    "    \"\"\"\n",
    "    non_letters = re.compile(\"[^a-zA-Z]\")\n",
    "\n",
    "    # The markers consist of letters only, so they survive this step untouched.\n",
    "    cleaned = non_letters.sub(\" \", row['TextClean'])\n",
    "\n",
    "    # Replace names, A first and then B.\n",
    "    # NOTE(review): names are matched as plain substrings (no word boundaries), and a\n",
    "    # B that contains A no longer matches once A has been replaced.\n",
    "    for column in ('A', 'B'):\n",
    "        name = non_letters.sub(\" \", row[column])\n",
    "        cleaned = re.sub(str(name), tokenizer.tokenize(name)[0], cleaned)\n",
    "\n",
    "    bracketed_markers = (\n",
    "        (r\"THISISA\", r\"[THISISA]\"),\n",
    "        (r\"THISISB\", r\"[THISISB]\"),\n",
    "        (r\"THISISP\", r\"[THISISP]\"),\n",
    "    )\n",
    "    for bare, bracketed in bracketed_markers:\n",
    "        cleaned = re.sub(bare, bracketed, cleaned)\n",
    "\n",
    "    return re.sub(' +', ' ', cleaned)\n",
    "\n",
    "def generate_text(row):\n",
    "    \"\"\"\n",
    "    Produce the text that gets tokenized for one data row.\n",
    "\n",
    "    The tagged text from insert_tag() is stored on the row itself under\n",
    "    'TextClean' (the row object is modified), and the cleaned, name-replaced\n",
    "    version of it is returned.\n",
    "    \"\"\"\n",
    "    row.loc['TextClean'] = insert_tag(row)\n",
    "    return clean_and_replace_target_name(row)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tokenize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Register the three position markers in the tokenizer vocabulary\n",
    "# (presumably so the tokenizer emits them as-is -- TODO confirm).\n",
    "# The id -1 is only a placeholder: tokenize() below removes the markers from the\n",
    "# token list, so they are never among the tokens that get converted to ids.\n",
    "for _marker in (\"[THISISA]\", \"[THISISB]\", \"[THISISP]\"):\n",
    "    tokenizer.vocab[_marker] = -1\n",
    "\n",
    "def tokenize(text, tokenizer):\n",
    "    \"\"\"\n",
    "    Returns a list of tokens and the positions of A, B, and the pronoun.\n",
    "    \"\"\"\n",
    "    entries = {}\n",
    "    final_tokens = []\n",
    "    for token in tokenizer.tokenize(text):\n",
    "        if token in (\"[THISISA]\", \"[THISISB]\", \"[THISISP]\"):\n",
    "            entries[token] = len(final_tokens) + 1\n",
    "            continue\n",
    "        final_tokens.append(token)\n",
    "    return final_tokens, (entries[\"[THISISA]\"], entries[\"[THISISB]\"], entries[\"[THISISP]\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every example: build the tagged and cleaned text, tokenize it, and keep both\n",
    "# the marker positions and the token ids wrapped in [CLS] ... [SEP].\n",
    "offsets_lst, tokens_lst = [], []\n",
    "for _, row in train_df.iterrows():\n",
    "    tokens, offsets = tokenize(generate_text(row), tokenizer)\n",
    "    token_ids = tokenizer.convert_tokens_to_ids([\"[CLS]\", *tokens, \"[SEP]\"])\n",
    "    offsets_lst.append(offsets)\n",
    "    tokens_lst.append(token_ids)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pad the sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Length of the longest id sequence, [CLS] and [SEP] included.\n",
    "max(map(len, tokens_lst))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Right-pad every id sequence with zeros, or truncate it, to exactly max_len.\n",
    "# Zero entries are what the attention mask in the embedding step treats as padding.\n",
    "\n",
    "max_len = 269\n",
    "tokens = np.zeros((len(tokens_lst), max_len), dtype=np.int64)\n",
    "for i, row in enumerate(tokens_lst):\n",
    "    # Slice with max_len itself rather than a second hard-coded 269, so the array\n",
    "    # width and the truncation length can never disagree when max_len is changed.\n",
    "    row = np.array(row[:max_len])\n",
    "    tokens[i, :len(row)] = row\n",
    "\n",
    "# All sentences, one padded row per example\n",
    "token_tensor = torch.from_numpy(tokens)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate Embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled GPU selection, kept for reference:\n",
    "#import os\n",
    "#os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1'\n",
    "#torch.cuda.set_device(1)\n",
    "bert = BertModel.from_pretrained(BERT_MODEL)\n",
    "\n",
    "# The model is used for feature extraction only. An nn.Module starts out in training\n",
    "# mode, where dropout is active; torch.no_grad() does not change that. Switch to eval\n",
    "# mode so the saved embeddings are deterministic instead of dropout-perturbed.\n",
    "# (The trailing semicolon just suppresses the model repr in the cell output.)\n",
    "bert.eval();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run each padded sequence through BERT on its own (batch of one) and collect the\n",
    "# encoded output (output_all_encoded_layers=False) for every example.\n",
    "bert_outputs = []\n",
    "with torch.no_grad():\n",
    "    for i, token_ids in enumerate(token_tensor):\n",
    "        if i % 40 == 0:\n",
    "            print(i)\n",
    "        batch = token_ids.unsqueeze(0)\n",
    "        # Positions holding id 0 are padding and get a 0 in the attention mask.\n",
    "        attention_mask = (batch > 0).long()\n",
    "        bert_output, _ = bert(\n",
    "            batch,\n",
    "            attention_mask=attention_mask,\n",
    "            token_type_ids=None,\n",
    "            output_all_encoded_layers=False,\n",
    "        )\n",
    "        bert_outputs.append(bert_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prefix for the output file names: '' writes offsets_lst.pkl etc.; set it to\n",
    "# 'test_' to write test_offsets_lst.pkl etc. instead.\n",
    "OUTPUT_PREFIX = ''\n",
    "\n",
    "artifacts = (\n",
    "    (offsets_lst, 'offsets_lst.pkl'),\n",
    "    (tokens_lst, 'token_lst_wto_padding.pkl'),\n",
    "    (bert_outputs, 'bert_outputs.pkl'),\n",
    ")\n",
    "for obj, filename in artifacts:\n",
    "    # 'with' closes (and thereby flushes) each file, even if pickling fails part-way.\n",
    "    with open(OUTPUT_PREFIX + filename, \"wb\") as fh:\n",
    "        pickle.dump(obj, fh)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
